[water] Support variadic reduction ops in Water dialect and add corresponding simplification pass #1053

Merged
martin-luecke merged 2 commits into main from users/martin/multi_operand_reduce on Mar 6, 2026

@martin-luecke
Contributor

Extends the Water dialect reduction ops (wave.sum, wave.max_element) to accept variadic inputs, matching the PyWave representation, in which expand_graph tiles reduction inputs into a list of slices. This simplifies FX <-> MLIR roundtrips: the dialect can now represent the intermediate form directly, so the Python side no longer has to decompose reductions before emission, track which reductions came from that decomposition, and fuse them again for the roundtrip.

A new ExpandVariadicReductions pass chains N variadic inputs into N single-input reductions, each feeding its result to the next as the accumulator. It is a partial port of the logic in PyWave's decompose_reduce_ops pass. Both the Water emitter and the FX importer have been updated to handle the variadic form.
A normal-form annotation for expanded reductions could be added to mark where in the pipeline single-input reductions are expected, though I think this would currently only be relevant for codegen.
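
To make the chaining concrete, here is a minimal standalone sketch of the idea in plain C++. It does not use MLIR or the actual pass code: `SingleReduction`, `expandVariadic`, and `chainedSum` are hypothetical names invented for illustration, and the generated value names (`%0`, `%1`, ...) are an assumption, not what the pass emits.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <string>
#include <vector>

// Toy stand-in for a single-input reduction; not the real Water op class.
struct SingleReduction {
  std::string result; // name of the value this step produces
  std::string input;  // the one slice being reduced
  std::string init;   // accumulator fed into this step
};

// Chain N inputs into N single-input reductions. The first step uses the
// original init; every later step uses the previous step's result.
std::vector<SingleReduction>
expandVariadic(const std::vector<std::string> &inputs,
               const std::string &init) {
  assert(!inputs.empty() && "a reduction needs at least one input");
  std::vector<SingleReduction> chain;
  std::string acc = init;
  for (std::size_t i = 0; i < inputs.size(); ++i) {
    std::string result = "%" + std::to_string(i);
    chain.push_back({result, inputs[i], acc});
    acc = result;
  }
  return chain;
}

// Numeric view of the same chain for a sum: each slice is reduced on top of
// the running accumulator, so the final value covers all slices plus init.
double chainedSum(const std::vector<std::vector<double>> &slices,
                  double init) {
  double acc = init;
  for (const auto &slice : slices)
    acc = std::accumulate(slice.begin(), slice.end(), acc);
  return acc;
}
```

Calling `expandVariadic({"%a", "%b", "%c"}, "%init")` in this model yields three steps whose accumulators are `%init`, `%0`, and `%1`, which mirrors the shape of the chain described above.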

Signed-off-by: Martin Lücke <martin.luecke@amd.com>
@martin-luecke martin-luecke requested a review from ftynse March 5, 2026 21:50
@martin-luecke martin-luecke changed the title [water] Support variadic reduction ops in Water dialect and corresponding simplification pass [water] Support variadic reduction ops in Water dialect and add corresponding simplification pass Mar 5, 2026
/// %1 = wave.sum %b init(%0) <scope>
/// %r = wave.sum %c init(%1) <scope>
template <typename ReductionOp>
struct ExpandVariadicReduction : public OpRewritePattern<ReductionOp> {
Contributor

We have a trait for reductions; would it make sense to make this an OpTraitRewritePattern? I'm very open to arguments here, since traits don't provide named accessors... Related discussion here: #992 (comment).

Contributor Author

Good idea in principle, but, as you said, traits don't give us named accessors or the typed create() builder. With only two reduction ops and the template giving us full type safety, I think explicit instantiation is the better tradeoff here. Of course, we will have to remember to add any new reduction ops to the patterns.add call.
As mentioned in the related discussion, we could also model this as an interface. I don't have a strong opinion here.
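
The tradeoff can be illustrated outside MLIR with a small plain C++ sketch. Everything here is a hypothetical stand-in: `SumOp`, `MaxElementOp`, and `expandToText` are invented for this example and only mimic the idea of a pattern templated over the concrete op type. The printed form also omits the scope attribute.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-ins for the two reduction ops. The real op classes are
// generated by MLIR and look different; only the named accessors matter here.
struct SumOp {
  std::vector<std::string> inputs;
  std::string init;
  static const char *name() { return "wave.sum"; }
  const std::vector<std::string> &getInputs() const { return inputs; }
  const std::string &getInit() const { return init; }
};

struct MaxElementOp {
  std::vector<std::string> inputs;
  std::string init;
  static const char *name() { return "wave.max_element"; }
  const std::vector<std::string> &getInputs() const { return inputs; }
  const std::string &getInit() const { return init; }
};

// Templated over the concrete op type: getInputs(), getInit(), and name()
// are resolved at compile time, so a type lacking them fails to compile.
template <typename ReductionOp>
std::vector<std::string> expandToText(const ReductionOp &op) {
  std::vector<std::string> lines;
  std::string acc = op.getInit();
  std::size_t i = 0;
  for (const std::string &input : op.getInputs()) {
    std::string result = "%" + std::to_string(i++);
    lines.push_back(result + " = " + ReductionOp::name() + " " + input +
                    " init(" + acc + ")");
    acc = result;
  }
  return lines;
}
```

The cost noted above shows up here too: each op type has to be named explicitly at the call site (`expandToText(SumOp{...})`, `expandToText(MaxElementOp{...})`), so a newly added reduction type is not picked up automatically.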

supports variadic inputs for faithful roundtripping with the Python
representation. This pass normalizes them before lowering, which
requires single-input reductions.
}];
Contributor

Could you document the differences between this pass and its Python counterpart?

Signed-off-by: Martin Lücke <martin.luecke@amd.com>
@martin-luecke martin-luecke merged commit 7ee7276 into main Mar 6, 2026
25 of 26 checks passed
@martin-luecke martin-luecke deleted the users/martin/multi_operand_reduce branch March 6, 2026 17:41